package com.icesoft.system.safe.config;

import com.fasterxml.jackson.core.type.TypeReference;
import com.icesoft.framework.core.util.JSON;
import com.icesoft.system.safe.helper.RequestHelper;
import com.icesoft.system.safe.service.SafeRequestCryptoService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Request wrapper that caches the request body and, for non-GET requests, replaces it with the
 * plaintext produced by {@link SafeRequestCryptoService#decryptReqBody(String)}.
 * <p>
 * Once {@link #resolveParam()} has run, the parameter accessors ({@link #getParameter},
 * {@link #getParameterValues}, {@link #getParameterMap}, {@link #getParameterNames}) are served
 * from the resolved parameter map instead of the wrapped request. Once the body has been cached,
 * {@link #getInputStream()} and {@link #getReader()} expose the cached (possibly decrypted) body.
 * <p>
 * On construction the wrapper registers itself on the wrapped request under {@link #ATTR_NAME}.
 */
@Slf4j
public class SafeHttpServletRequest extends HttpServletRequestWrapper {

	public static final String ATTR_NAME = "SafeHttpServletRequest";
	private final SafeRequestCryptoService safeRequestCryptoService;
	private final boolean isGetMethod;
	private final HttpServletRequest request;
	/** Resolved parameters; {@code null} until {@link #resolveParam()} has run. */
	private TreeMap<String, String> paramMap;
	/** Replacement body stream; {@code null} until the body has been read and cached. */
	private ServletInputStream inputStream;
	/** Reader over {@link #inputStream}; reset whenever the cached stream is replaced. */
	private BufferedReader reader;
	/** Raw body text with one leading/trailing double quote stripped; {@code null} until read. */
	private String body;

	public SafeHttpServletRequest(SafeRequestCryptoService safeRequestCryptoService, HttpServletRequest request) {
		super(request);
		this.request = request;
		this.safeRequestCryptoService = safeRequestCryptoService;
		isGetMethod = "GET".equals(request.getMethod());
		// Lets later code retrieve this wrapper from the underlying request.
		request.setAttribute(ATTR_NAME, this);
	}

	/**
	 * Reads the wrapped request's body once, caches it, and returns it as UTF-8 text.
	 * <p>
	 * A single leading and a single trailing {@code "} are removed from the returned text. The
	 * raw bytes are kept so that {@link #getInputStream()} / {@link #getReader()} can still
	 * serve them after this call.
	 *
	 * @return the cached body text (never {@code null} once read)
	 * @throws IOException if reading the wrapped request's input stream fails
	 */
	public String getBody() throws IOException {
		if (body != null) {
			return body;
		}
		byte[] data = IOUtils.toByteArray(request.getInputStream());
		cacheBody(data);
		String text = new String(data, StandardCharsets.UTF_8);
		if (text.startsWith("\"")) {
			text = text.substring(1);
		}
		if (text.endsWith("\"")) {
			text = text.substring(0, text.length() - 1);
		}
		body = text;
		return body;
	}

	/**
	 * Resolves (once) the request parameters and caches them.
	 * <ul>
	 * <li>GET requests, or requests with a blank body: the parameters from
	 * {@code RequestHelper.getRequestParamMap(request)}.</li>
	 * <li>Otherwise: the body is decrypted with the configured {@link SafeRequestCryptoService},
	 * the decrypted text replaces the cached body, and it is parsed as a JSON object of
	 * string values.</li>
	 * </ul>
	 *
	 * @return the resolved, key-sorted parameter map (never {@code null})
	 * @throws IOException if reading the body fails
	 */
	public TreeMap<String, String> resolveParam()
			throws IOException {
		if (paramMap != null) {
			return paramMap;
		}
		if (isGetMethod) {
			paramMap = new TreeMap<>(RequestHelper.getRequestParamMap(request));
			return paramMap;
		}
		String paramStr = getBody();
		if (StringUtils.isBlank(paramStr)) {
			// Nothing to decrypt: fall back to the request's own parameters.
			paramMap = new TreeMap<>(RequestHelper.getRequestParamMap(request));
			return paramMap;
		}
		if (log.isTraceEnabled()) {
			log.trace("解密类: {}", safeRequestCryptoService.getClass());
		}
		paramStr = safeRequestCryptoService.decryptReqBody(paramStr);
		// From here on, body readers see the decrypted text rather than the raw body.
		cacheBody(paramStr.getBytes(StandardCharsets.UTF_8));
		log.trace("请求json：{}", paramStr);
		TreeMap<String, String> parsed = null;
		if (StringUtils.isNotBlank(paramStr)) {
			parsed = JSON.parseObject(paramStr, new TypeReference<TreeMap<String, String>>() {
			});
		}
		// Never leave paramMap null here, otherwise every call would decrypt and parse again.
		paramMap = parsed != null ? parsed : new TreeMap<>();
		return paramMap;
	}

	/**
	 * Replaces the body exposed through {@link #getInputStream()} / {@link #getReader()} and
	 * drops any reader bound to the previous stream.
	 */
	private void cacheBody(byte[] data) {
		inputStream = new RequestCachingInputStream(data);
		reader = null;
	}

	/**
	 * Returns a reader over the cached body when one exists, otherwise the wrapped request's
	 * reader. The request's character encoding is used, defaulting to UTF-8 when none is set.
	 */
	@Override
	public BufferedReader getReader() throws IOException {
		if (inputStream == null) {
			return super.getReader();
		}
		if (reader == null) {
			String encoding = getCharacterEncoding();
			reader = new BufferedReader(new InputStreamReader(inputStream,
					encoding != null ? encoding : StandardCharsets.UTF_8.name()));
		}
		return reader;
	}

	/**
	 * Returns the cached body stream when the body has been cached, otherwise the wrapped
	 * request's stream. The cached stream is a single instance; once consumed it stays at EOF.
	 */
	@Override
	public ServletInputStream getInputStream() throws IOException {
		if (inputStream != null) {
			return inputStream;
		}
		return super.getInputStream();
	}

	@Override
	public String getParameter(String name) {
		if (paramMap != null) {
			return paramMap.get(name);
		}
		return super.getParameter(name);
	}

	/**
	 * Returns a one-element array from the resolved map, or {@code null} when the parameter is
	 * absent (as required by the Servlet contract). Delegates to the wrapped request when
	 * {@link #resolveParam()} has not run.
	 */
	@Override
	public String[] getParameterValues(String name) {
		if (paramMap != null) {
			String value = paramMap.get(name);
			return value == null ? null : new String[]{value};
		}
		return super.getParameterValues(name);
	}

	/**
	 * Returns an unmodifiable view of the resolved parameters, with keys in the same sorted
	 * order as {@link #getParameterNames()}. Delegates to the wrapped request when
	 * {@link #resolveParam()} has not run.
	 */
	@Override
	public Map<String, String[]> getParameterMap() {
		if (paramMap != null) {
			Map<String, String[]> map = new LinkedHashMap<>();
			for (Map.Entry<String, String> entry : paramMap.entrySet()) {
				map.put(entry.getKey(), new String[]{entry.getValue()});
			}
			return Collections.unmodifiableMap(map);
		}
		return super.getParameterMap();
	}

	@Override
	public Enumeration<String> getParameterNames() {
		if (paramMap != null) {
			return Collections.enumeration(paramMap.keySet());
		}
		return super.getParameterNames();
	}

	/** {@link ServletInputStream} backed by an in-memory byte array; always ready, never blocks. */
	private static class RequestCachingInputStream extends ServletInputStream {

		private final ByteArrayInputStream inputStream;

		public RequestCachingInputStream(byte[] bytes) {
			inputStream = new ByteArrayInputStream(bytes);
		}

		@Override
		public int read() throws IOException {
			return inputStream.read();
		}

		// Bulk read delegates directly instead of looping over read() byte by byte.
		@Override
		public int read(byte[] b, int off, int len) {
			return inputStream.read(b, off, len);
		}

		@Override
		public int available() {
			return inputStream.available();
		}

		@Override
		public boolean isFinished() {
			return inputStream.available() == 0;
		}

		@Override
		public boolean isReady() {
			return true;
		}

		@Override
		public void setReadListener(ReadListener readlistener) {
			// Data is fully in memory; asynchronous read notification is not supported.
		}

	}

}
